/* Copyright 2020 The MLPerf Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MLPERF_DATASETS_SQUAD_UTILS_TYPES_H_
#define MLPERF_DATASETS_SQUAD_UTILS_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature_util.h"
#include "tensorflow/core/platform/types.h"

namespace mlperf {
namespace mobile {

// GroundTruthRecord is equivalent to a record in the ground truth tfrecord
// file. It is parsed from one serialized tensorflow::Example and holds the
// document words and tokens, the list of accepted answers, and the question id.
struct GroundTruthRecord {
  // Parses `record` as a serialized tensorflow::Example.
  // CHECK-fails if the record cannot be parsed or if the "qas_id" feature has
  // no values.
  explicit GroundTruthRecord(const tensorflow::tstring& record) {
    using string = google::protobuf::string;
    tensorflow::Example example;
    CHECK(example.ParseFromString(record));

    // Binding by const reference avoids copying the repeated field when
    // GetFeatureValues returns a reference, and is still safe (lifetime
    // extension) if it returns by value.
    const auto& qas_id_values =
        tensorflow::GetFeatureValues<string>("qas_id", example);
    // Indexing [0] into an empty repeated field would read out of bounds.
    CHECK(qas_id_values.size() > 0);
    qas_id = qas_id_values[0];

    const auto& token_values =
        tensorflow::GetFeatureValues<string>("tokens", example);
    tokens.assign(token_values.begin(), token_values.end());

    const auto& word_values =
        tensorflow::GetFeatureValues<string>("words", example);
    words.assign(word_values.begin(), word_values.end());

    const auto& answer_values =
        tensorflow::GetFeatureValues<string>("answers", example);
    answers.assign(answer_values.begin(), answer_values.end());
  }

  // The words generated from the document by splitting on spaces.
  std::vector<std::string> words;
  // Tokenized strings generated from the words.
  std::vector<std::string> tokens;
  // The list of ground truth answers.
  std::vector<std::string> answers;
  // The id of the question.
  std::string qas_id;
};

// SampleRecord is equivalent to a record in the input tfrecord file. It is
// parsed from one serialized tensorflow::Example and holds the MobileBert model
// inputs plus the metadata needed to post-process the model output.
struct SampleRecord {
  // Parses `record` as a serialized tensorflow::Example.
  // CHECK-fails if:
  //   - the record cannot be parsed,
  //   - the "qas_id" feature has no values, or
  //   - span_tokens_ is not longer than token_index_map_, which would make
  //     query_tokens_length_ negative.
  explicit SampleRecord(const tensorflow::tstring& record) {
    using int64 = google::protobuf::int64;
    using string = google::protobuf::string;
    tensorflow::Example example;
    CHECK(example.ParseFromString(record));

    // Feature values are bound by const reference to avoid copying the
    // repeated field when GetFeatureValues returns a reference. This is also
    // safe (lifetime extension) if it returns by value.
    //
    // Data is stored as int64 in the tfrecord file, so it is narrowed to int32
    // here. Input_ids is in range [0, 30000).
    const auto& input_ids_values =
        tensorflow::GetFeatureValues<int64>("input_ids", example);
    input_ids_.assign(input_ids_values.begin(), input_ids_values.end());
    // Input_mask is in range [0, 1].
    const auto& input_mask_values =
        tensorflow::GetFeatureValues<int64>("input_mask", example);
    input_mask_.assign(input_mask_values.begin(), input_mask_values.end());
    // Segment_ids is in range [0, 1].
    const auto& segment_ids_values =
        tensorflow::GetFeatureValues<int64>("segment_ids", example);
    segment_ids_.assign(segment_ids_values.begin(), segment_ids_values.end());

    // Get question id and tokens. Indexing [0] into an empty repeated field
    // would read out of bounds, so require at least one value.
    const auto& qas_id_values =
        tensorflow::GetFeatureValues<string>("qas_id", example);
    CHECK(qas_id_values.size() > 0);
    qas_id_ = qas_id_values[0];
    const auto& span_tokens_values =
        tensorflow::GetFeatureValues<string>("tokens", example);
    span_tokens_.assign(span_tokens_values.begin(), span_tokens_values.end());

    // Query tokens are included at the beginning of span_tokens_, but they do
    // not have values in token_index_map_ and token_is_max_context_.
    const auto& index_map_values =
        tensorflow::GetFeatureValues<int64>("token_to_orig_map", example);
    token_index_map_.assign(index_map_values.begin(), index_map_values.end());

    const auto& is_max_context_values =
        tensorflow::GetFeatureValues<int64>("token_is_max_context", example);
    token_is_max_context_.assign(is_max_context_values.begin(),
                                 is_max_context_values.end());

    // The subtraction below is done on size_t. Guard against unsigned
    // wraparound before narrowing the result to int.
    // NOTE(review): the trailing "- 1" presumably skips a final separator
    // token; confirm against the data pipeline.
    CHECK(span_tokens_.size() > token_index_map_.size());
    query_tokens_length_ =
        static_cast<int>(span_tokens_.size() - token_index_map_.size() - 1);
  }

  // input_ids_, input_mask_ and segment_ids_ are the inputs of the MobileBert
  // model, in that order.
  std::vector<int32_t> input_ids_;
  std::vector<int32_t> input_mask_;
  std::vector<int32_t> segment_ids_;
  // The following fields are not model inputs, but they are needed to
  // post-process the output of the model.
  std::string qas_id_;
  std::vector<std::string> span_tokens_;
  // The length of the query tokens. Query tokens do not have values in
  // token_index_map_ and token_is_max_context_.
  int query_tokens_length_ = 0;
  // Map from span token index to original doc token index.
  std::vector<int32_t> token_index_map_;
  // A token may exist in many spans. We only want to consider the score with
  // "maximum context", i.e. the span where the minimum of the token's left and
  // right context is largest.
  std::vector<int32_t> token_is_max_context_;
};

// Stores the output prediction of the MobileBert model: the start logits and
// the end logits, copied out of the caller's buffers.
struct MobileBertPrediction {
  // Copies `start_logit_size` floats from `start_logit` and `end_logit_size`
  // floats from `end_logit`. The buffers are only read, so the caller keeps
  // ownership and may release them after construction. A null pointer is only
  // valid together with a size of 0, and `end_logit_size` must be non-negative.
  MobileBertPrediction(const float* start_logit, uint32_t start_logit_size,
                       const float* end_logit, int end_logit_size)
      : start_logit_(start_logit_size), end_logit_(end_logit_size) {
    std::copy(start_logit, start_logit + start_logit_size, start_logit_.data());
    std::copy(end_logit, end_logit + end_logit_size, end_logit_.data());
  }

  // Returns the output data in binary format: the raw bytes of start_logit_
  // followed by the raw bytes of end_logit_, in the host's float
  // representation.
  std::vector<uint8_t> GetData() const {
    const size_t start_logit_bytes = start_logit_.size() * sizeof(float);
    const size_t end_logit_bytes = end_logit_.size() * sizeof(float);
    std::vector<uint8_t> data(start_logit_bytes + end_logit_bytes);
    // memcpy with a null pointer is undefined even for zero bytes, and an
    // empty vector's data() may be null, so skip empty copies.
    if (start_logit_bytes > 0) {
      std::memcpy(data.data(), start_logit_.data(), start_logit_bytes);
    }
    if (end_logit_bytes > 0) {
      std::memcpy(data.data() + start_logit_bytes, end_logit_.data(),
                  end_logit_bytes);
    }
    return data;
  }

  std::vector<float> start_logit_;
  std::vector<float> end_logit_;
};
}  // namespace mobile
}  // namespace mlperf

#endif  // MLPERF_DATASETS_SQUAD_UTILS_TYPES_H_
